#----------------------------------------------
# -*- encoding=utf-8 -*-                      #
# __author__:'xiaojie'                        #
# CreateTime:                                 #
#       2019/5/21 11:37                       #
#                                             #
#               天下风云出我辈，                 #
#               一入江湖岁月催。                 #
#               皇图霸业谈笑中，                 #
#               不胜人生一场醉。                 #
#----------------------------------------------

import time
from .ops import *
from .utils import *
from tensorflow.contrib.data.python.ops.batching import map_and_batch
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
# from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device

class SAGAN():

    def __init__(self,sess,args):
        """Copy the run configuration onto the instance, load the data and print a summary.

        Args:
            sess: session object stored on ``self.sess``; the other methods use it to
                run graph ops and to save/restore checkpoints.
            args: configuration namespace. The attributes read here are ``dataset``,
                ``checkpoint_dir``, ``sample_dir``, ``result_dir``, ``log_dir``,
                ``epoch``, ``iteration``, ``batch_size``, ``print_freq``, ``save_freq``,
                ``img_size``, ``z_dim``, ``up_sample``, ``gan_type``, ``n_critic``,
                ``sn``, ``ld``, ``sample_num``, ``test_num``, ``g_lr``, ``d_lr``,
                ``beta1`` and ``beta2``.
        """
        self.model_name = "SAGAN"  # name for checkpoint
        self.sess = sess

        # locations
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        # schedule
        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        # Generator
        # layer_num is the number of stride-2 stages looped over in both networks
        self.layer_num = int(np.log2(self.img_size)) - 3
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.up_sample = args.up_sample
        self.gan_type = args.gan_type

        # Discriminator
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # optimizer settings
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        # Both loader paths below leave this flag False, so build_model()
        # always feeds images through a placeholder.
        self.custom_dataset = False

        is_mnist = self.dataset_name == 'mnist'
        if is_mnist:
            self.c_dim = 1
            self.data = load_mnist(size=self.img_size)
        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)

        self.dataset_num = len(self.data)

        # samples go into a sub-folder named after the model configuration
        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        # configuration banner: (section title, [(label, value), ...])
        banner = [
            ("##### Information #####", [
                ("# gan type : ", self.gan_type),
                ("# dataset : ", self.dataset_name),
                ("# dataset number : ", self.dataset_num),
                ("# batch_size : ", self.batch_size),
                ("# epoch : ", self.epoch),
                ("# iteration per epoch : ", self.iteration),
            ]),
            ("##### Generator #####", [
                ("# generator layer : ", self.layer_num),
                ("# upsample conv : ", self.up_sample),
            ]),
            ("##### Discriminator #####", [
                ("# discriminator layer : ", self.layer_num),
                ("# the number of critic : ", self.n_critic),
                ("# spectral normalization : ", self.sn),
            ]),
        ]
        for title, rows in banner:
            print(title)
            for label, value in rows:
                print(label, value)

    ##################################################################################
    # Generator
    ##################################################################################
    def generator(self,z,is_training=True,reuse=False):
        """Build the generator graph under the ``generator`` variable scope.

        Args:
            z: noise tensor fed to the first transposed convolution.
            is_training: forwarded to every ``batch_norm`` call.
            reuse: passed to the variable scope (and to the attention block) so an
                already-built generator can share its variables.

        Returns:
            The ``tanh``-activated output tensor with ``self.c_dim`` channels.
        """
        def grow(net, out_ch, idx, conv_prefix):
            # One x2 stage: (up_sample + 3x3 conv) or a stride-2 deconv, then BN + ReLU.
            if self.up_sample:
                net = up_sample(net, scale_factor=2)
                net = conv(net, channels=out_ch, kernel=3, stride=1, pad=1,
                           sn=self.sn, scope=conv_prefix + str(idx))
            else:
                net = deconv(net, channels=out_ch, kernel=4, stride=2, use_bias=False,
                             sn=self.sn, scope='deconv_' + str(idx))
            net = batch_norm(net, is_training, scope='batch_norm_' + str(idx))
            return relu(net)

        with tf.variable_scope('generator',reuse=reuse):
            width = 1024
            half_depth = self.layer_num // 2

            # stem: 4x4 VALID transposed convolution applied to the noise tensor
            net = deconv(z, channels=width, kernel=4, stride=1, padding='VALID',
                         use_bias=False, sn=self.sn, scope='deconv')
            net = relu(batch_norm(net, is_training, scope='batch_norm'))

            # first half of the stages (conv scope prefix here has no underscore;
            # kept as-is so variable names stay compatible with saved checkpoints)
            for stage in range(half_depth):
                net = grow(net, width // 2, stage, 'up_conv')
                width //= 2

            # self Attention
            net = self.attention(net, width, sn=self.sn, scope='attention', reuse=reuse)

            # remaining stages
            for stage in range(half_depth, self.layer_num):
                net = grow(net, width // 2, stage, 'up_conv_')
                width //= 2

            # output layer: project to image channels, no batch norm
            if self.up_sample:
                net = up_sample(net, scale_factor=2)
                net = conv(net, channels=self.c_dim, kernel=3, stride=1, pad=1,
                           sn=self.sn, scope='G_conv_logit')
            else:
                net = deconv(net, channels=self.c_dim, kernel=4, stride=2, use_bias=False,
                             sn=self.sn, scope='G_deconv_logit')

            return tanh(net)

    ##################################################################################
    # Discriminator
    ##################################################################################
    def discriminator(self,x,is_training=True,reuse=False):
        """Build the discriminator graph under the ``discriminator`` variable scope.

        Args:
            x: image tensor to score.
            is_training: forwarded to every ``batch_norm`` call.
            reuse: passed to the variable scope (and to the attention block) so the
                same weights can be applied to several inputs.

        Returns:
            The output of the final ``D_logit`` convolution (no activation applied).
        """
        def shrink(net, out_ch, idx, bn_prefix):
            # One stride-2 stage: 4x4 conv, batch norm, leaky ReLU(0.2).
            net = conv(net, channels=out_ch, kernel=4, stride=2, pad=1,
                       sn=self.sn, use_bias=False, scope='conv_' + str(idx))
            net = batch_norm(net, is_training, scope=bn_prefix + str(idx))
            return lrelu(net, 0.2)

        with tf.variable_scope('discriminator',reuse=reuse):
            width = 64
            half_depth = self.layer_num // 2

            # stem: stride-2 conv without batch norm
            print('before',x.shape.as_list()) #  [32, 128, 128, 1]
            net = conv(x, channels=width, kernel=4, stride=2, pad=1,
                       sn=self.sn, use_bias=False, scope='conv')
            print('after',net.shape.as_list()) # [32, 64, 64, 64]
            net = lrelu(net, 0.2)

            for stage in range(half_depth):
                net = shrink(net, width * 2, stage, 'batch_norm_')
                width *= 2

            # Self Attention
            net = self.attention(net, width, sn=self.sn, scope='attention', reuse=reuse)

            # batch-norm scope prefix here has no underscore; kept as-is so
            # variable names stay compatible with saved checkpoints
            for stage in range(half_depth, self.layer_num):
                net = shrink(net, width * 2, stage, 'batch_norm')
                width *= 2

            logits = conv(net, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')
            print('BBBBBBBBBBBBB',logits.shape.as_list())
            return logits

    def attention(self,x,ch,sn=False,scope='attention',reuse=False):
        """Apply a self-attention block and add its scaled output back onto ``x``.

        Three 1x1 convolutions of ``x`` are computed (``f`` and ``g`` with ``ch // 8``
        channels, ``h`` with ``ch`` channels). The product of the flattened ``g`` and
        ``f`` maps is soft-maxed into an attention map, which then weights ``h``.
        The result is scaled by a learnable scalar ``gamma`` (initialised to 0) and
        added to the input.

        Args:
            x: input feature map; its static shape is used to reshape the result.
            ch: channel count used to size the three projections.
            sn: forwarded to ``conv``.
            scope: variable scope name for this block.
            reuse: passed to the variable scope.

        Returns:
            ``gamma * attended + x``.
        """
        print('CCCCCCCCCCC', ch) # x.shape [32, 16, 16, 256]
        with tf.variable_scope(scope,reuse=reuse):
            reduced = ch // 8
            f_map = conv(x, reduced, kernel=1, stride=1, sn=sn, scope='f_conv')
            g_map = conv(x, reduced, kernel=1, stride=1, sn=sn, scope='g_conv')
            h_map = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv')
            print('f.shape',f_map.shape.as_list(),h_map.shape.as_list())

            # pairwise scores between flattened positions
            # (presumably [bs, N, N] with N = h*w, per hw_flatten -- confirm in ops)
            scores = tf.matmul(hw_flatten(g_map), hw_flatten(f_map), transpose_b=True)

            # attention map: softmax over the last axis
            attn = tf.nn.softmax(scores)

            attended = tf.matmul(attn, hw_flatten(h_map))
            gamma = tf.get_variable('gamma', [1], initializer=tf.constant_initializer(0.0))

            # back to the layout of the input before the residual add
            attended = tf.reshape(attended, shape=x.shape)
            out = gamma * attended + x

        return out

    def gradient_penalty(self,real,fake):
        """Return the gradient-penalty term for the configured ``gan_type``.

        The discriminator is evaluated (with shared weights) on a perturbed or
        interpolated batch, and the L2 norm of its gradient w.r.t. that batch is
        penalised.

        Args:
            real: batch of real images.
            fake: batch of generated images (unused for ``dragan``).

        Returns:
            ``self.ld`` times the mean penalty for ``wgan-lp``, ``wgan-gp`` and
            ``dragan``; the integer ``0`` for any other ``gan_type``.
        """
        if self.gan_type != 'dragan':
            # random point on the segment between each real/fake pair
            mix = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
            interpolated = mix * real + (1. - mix) * fake
        else:
            real_shape = tf.shape(real)
            eps = tf.random_uniform(shape=real_shape, minval=0., maxval=1.)
            _, variance = tf.nn.moments(real, axes=[0, 1, 2, 3])
            # magnitude of noise decides the size of local region (delta in paper)
            noise = 0.5 * tf.sqrt(variance) * eps

            # Author suggested U[0,1] in original paper, but he admitted it is bug in github
            # (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided.
            alpha = tf.random_uniform(shape=[real_shape[0], 1, 1, 1], minval=-1., maxval=1.)
            # x_hat should be in the space of X
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)

        critic_out = self.discriminator(interpolated, reuse=True)

        # gradient of D(interpolated) and its per-row l2 norm
        gradient = tf.gradients(critic_out, interpolated)[0]
        grad_norm = tf.norm(flatten(gradient), axis=1)

        # WGAN - LP: one-sided, only norms above 1 are penalised
        if self.gan_type == 'wgan-lp':
            return self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        # WGAN - GP / DRAGAN: two-sided
        if self.gan_type in ('wgan-gp', 'dragan'):
            return self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return 0

    ##################################################################################
    # Model
    ##################################################################################
    def build_model(self):
        """Construct the training/testing graph and attach its handles to ``self``.

        Sets ``self.inputs``, ``self.z``, ``self.d_loss``, ``self.g_loss``,
        ``self.d_optim``, ``self.g_optim``, ``self.fake_images``, ``self.d_sum``
        and ``self.g_sum``.
        """
        # ---- Graph Input: images ----
        if not self.custom_dataset:
            image_shape = [self.batch_size, self.img_size, self.img_size, self.c_dim]
            self.inputs = tf.placeholder(tf.float32, image_shape, name='real_images')
        else:
            preprocessor = ImageData(self.img_size, self.c_dim)
            pipeline = tf.data.Dataset.from_tensor_slices(self.data)
            pipeline = pipeline.apply(shuffle_and_repeat(self.dataset_num))
            pipeline = pipeline.apply(
                map_and_batch(preprocessor.image_processing, self.batch_size,
                              num_parallel_batches=16, drop_remainder=True))
            self.inputs = pipeline.make_one_shot_iterator().get_next()

        # ---- Graph Input: noises ----
        self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

        # ---- Loss Function ----
        # output of D for real images
        real_logits = self.discriminator(self.inputs)
        print('AAAAAAAAAAA',self.inputs.shape.as_list())

        # output of D for fake images (same discriminator weights)
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        # only the wgan family and dragan use a gradient penalty
        needs_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images) if needs_penalty else 0

        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        # ---- Training ----
        # split trainable variables by the scope name embedded in their names
        trainables = tf.trainable_variables()
        d_vars = [v for v in trainables if 'discriminator' in v.name]
        g_vars = [v for v in trainables if 'generator' in v.name]

        d_adam = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

        g_adam = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        # inference-mode generator sharing the trained weights
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # ---- Summary ----
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)


    ##################################################################################
    # Train
    ##################################################################################
    def train(self):
        """Run the training loop, resuming from the latest checkpoint when one exists.

        Each iteration updates the discriminator once; the generator is updated
        only on iterations where ``(counter - 1) % self.n_critic == 0``. Loss
        summaries are written to ``self.log_dir``, a sample grid is saved every
        ``self.print_freq`` iterations and a checkpoint every ``self.save_freq``
        iterations, at the end of every epoch and once more when training ends.

        Requires ``build_model()`` to have been called first, since it reads
        ``self.d_optim``, ``self.g_optim``, the loss tensors and the placeholders.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # fixed noise batch reused for every periodic sample grid
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            # translate the global step back into (epoch, position within epoch)
            start_epoch, start_batch_id = divmod(checkpoint_counter, self.iteration)
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.  # shown until the first generator update has happened
        for epoch in range(start_epoch, self.epoch):
            for idx in range(start_batch_id, self.iteration):
                batch_z = np.random.uniform(-1, 1, [self.batch_size, 1, 1, self.z_dim])

                train_feed_dict = {self.z: batch_z}
                if not self.custom_dataset:
                    # placeholder input: draw a random batch without replacement
                    random_index = np.random.choice(self.dataset_num, size=self.batch_size, replace=False)
                    train_feed_dict[self.inputs] = self.data[random_index]

                # update D network (fetch the loss tensor itself for logging)
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # update G network once every n_critic discriminator updates
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                           feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    # G was skipped this step: report its most recent loss
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save a grid of generated samples every print_freq iterations
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    # largest square grid that fits in the available samples
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    sample_name = '{}_train_{:02d}_{:05d}.png'.format(self.model_name, epoch, idx + 1)
                    # os.path.join keeps an absolute sample_dir intact
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                os.path.join(self.sample_dir, sample_name))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Directory name that identifies this configuration.

        Joins model name, dataset name, GAN type, image size, noise dimension
        and the spectral-norm flag with underscores.
        """
        parts = (
            self.model_name,
            self.dataset_name,
            self.gan_type,
            self.img_size,
            self.z_dim,
            self.sn,
        )
        return "_".join(format(part) for part in parts)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint of the current session.

        Args:
            checkpoint_dir: base directory; the checkpoint goes into the
                ``self.model_dir`` sub-folder, which is created if missing.
            step: value passed to the saver as ``global_step``.
        """
        target_dir = os.path.join(checkpoint_dir, self.model_dir)
        if not os.path.exists(target_dir):
            os.makedirs(target_dir)

        prefix = os.path.join(target_dir, self.model_name + '.model')
        self.saver.save(self.sess, prefix, global_step=step)

    def load(self, checkpoint_dir):
        """Restore the most recent checkpoint for this configuration, if any.

        Args:
            checkpoint_dir: base directory; checkpoints are looked up in its
                ``self.model_dir`` sub-folder.

        Returns:
            ``(True, counter)`` when a checkpoint was restored, where ``counter``
            is the last run of digits in the checkpoint file name;
            ``(False, 0)`` when no checkpoint was found.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # Last group of digits in the name (a digit run not followed by any
            # later digit). Raw string so the backslashes are not invalid escapes.
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        """Generate one batch from fresh noise and save it as a square image grid.

        Args:
            epoch: number embedded (zero-padded to two digits) in the file name.
        """
        # largest square grid that fits in min(sample_num, batch_size) images
        grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))
        grid_cells = grid_side * grid_side

        # random noise
        noise = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))
        generated = self.sess.run(self.fake_images, feed_dict={self.z: noise})

        file_name = self.model_name + '_epoch%02d' % epoch + '_visualize.png'
        save_images(generated[:grid_cells, :, :, :],
                    [grid_side, grid_side],
                    self.sample_dir + '/' + file_name)

    def test(self):
        """Generate ``self.test_num`` image grids into the result directory.

        Variables are initialised, then the latest checkpoint is restored if
        one is found; generation proceeds either way (a load failure is only
        reported on stdout).
        """
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)

        output_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(output_dir)

        print(" [*] Load SUCCESS" if could_load else " [!] Load failed...")

        # largest square grid that fits in min(sample_num, batch_size) images
        grid_side = int(np.floor(np.sqrt(min(self.sample_num, self.batch_size))))
        grid_cells = grid_side * grid_side
        noise_shape = (self.batch_size, 1, 1, self.z_dim)

        # random noise, one fresh batch per output image
        for i in range(self.test_num):
            noise = np.random.uniform(-1, 1, size=noise_shape)
            generated = self.sess.run(self.fake_images, feed_dict={self.z: noise})

            save_images(generated[:grid_cells, :, :, :],
                        [grid_side, grid_side],
                        output_dir + '/' + self.model_name + '_test_{}.png'.format(i))


